"""
@Time    :  2020/11/22 17:12
@Author  :  Sun_Z_Z  
@FileName:  eval_model.py
@Institution:   Scut214
"""
import os
import torch
import cv2
import numpy as np

from train_model import RedBuleNet

# Directory tree that is scanned (recursively) for images to classify.
root = r'./'

# Run on the GPU when one is present, otherwise fall back to the CPU so the
# script is still usable on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The checkpoint is loaded as a ready-to-call module object (it is moved to a
# device and switched to eval mode right below).
# NOTE(review): presumably the file holds a fully pickled RedBuleNet, which is
# why the class is imported at the top of the file even though it is never
# referenced by name -- confirm against train_model.py.
net = torch.load('RedBuleNet-Adam-0.01.pt', map_location=device)
net.to(device)
net.eval()

# Inference only: disabling autograd avoids building a graph for every image.
with torch.no_grad():
    for dirpath, _dirs, filenames in os.walk(root):
        for filename in filenames:
            img_path = os.path.join(dirpath, filename)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread returns None for anything it cannot decode
                # (scripts, checkpoints, ...); the walk visits every file,
                # so skip those instead of crashing on them.
                continue

            # HWC image from OpenCV -> CHW, then add a leading batch axis.
            chw = np.ascontiguousarray(img.transpose(2, 0, 1))
            batch = torch.from_numpy(chw).unsqueeze(0).type(torch.float)

            pre = net(batch.to(device))
            # First output of the first sample below 0.5 maps to label 1,
            # anything else to label 0.
            ret = 1 if pre[0][0] < 0.5 else 0
            print(pre, 'ret:', ret)

            # Show the image and block until a key is pressed.
            cv2.imshow('ret', img)
            cv2.waitKey(0)
